import os
import click
import torch
import torch.distributed as dist
import yaml
from ema_pytorch import EMA
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

import utils.graph_lib
from utils.samplers import get_sampler
import wandb
from models.model_utils import get_model, get_preconditioned_model
from utils.datasets import get_dataset 
from utils.losses import get_loss
from utils.misc import dotdict
from utils.optimizers import WarmUpScheduler

# Allow TF32 for CUDA matmuls and for cuDNN. This makes training on A100s
# faster; both switches are process-wide and take effect at import time.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def init_wandb(opts):
    """Open a Weights & Biases run describing this training job.

    The run is named after the model, dataset and context length found in
    ``opts``, tagged with ``'training'`` plus the dataset name, and receives
    ``opts`` itself as the run config.

    Args:
        opts: Option container exposing ``model``, ``dataset`` and
            ``context_len`` as attributes.
    """
    run_name = f'{opts.model}-{opts.dataset}-{opts.context_len}'
    run_tags = ['training', opts.dataset]
    wandb.init(
        project='Discrete CFG',  # wandb project that collects these runs
        name=run_name,
        tags=run_tags,
        config=opts,  # hyperparameters / run metadata
    )

def _checkpoint_and_sample(opts, rank, device, dataset, sampling_fn, model, ema,
                           opt, scheduler, training_iter, context_len):
    """Write a snapshot for ``training_iter`` and plot samples from model and EMA.

    Must be called by every rank: it contains a barrier, and every rank runs
    the sampler. Only rank 0 writes the snapshot file. The model is put in
    eval mode for sampling and switched back to train mode before returning.
    """
    path = os.path.join(opts.dir, f'itr_{training_iter}/')
    os.makedirs(path, exist_ok=True)
    if rank == 0:
        save_ckpt(model, ema, opt, scheduler, os.path.join(path, 'snapshot.pt'))
    model.eval()
    dist.barrier()

    n_samples = 5000
    eval_cond = dataset.generate_cond(n_samples).to(device=device)
    # The trailing 100 is forwarded unchanged to the sampler
    # (presumably the number of sampling steps -- confirm in utils.samplers).
    new_sample = sampling_fn(model, (n_samples, context_len), eval_cond, 100)
    new_sample_ema = sampling_fn(ema, (n_samples, context_len), eval_cond, 100)

    dataset.plot_samples(new_sample, os.path.join(path, f'sample_{training_iter}.png'))
    dataset.plot_samples(new_sample_ema, os.path.join(path, f'sample_ema_{training_iter}.png'))
    model.train()


@click.command()
@click.option('--dataset',type=click.Choice(['vector-disjoint', 'vector-intersection', 'matrix-intersection', 'matrix-disjoint', 'gaussian-5d']), default='vector-disjoint')
@click.option('--context_len',type=int, default=50)
@click.option('--model',type=click.Choice(['radd']), default='radd')
@click.option('--optimizer',type=click.Choice(['adam','adamw']), default='adam')
@click.option('--lr', type=float, default=1e-5)
@click.option('--ema_beta', type=float, default=.9999)
@click.option('--batch_size', type=int, default=256)
@click.option('--log_rate',type=int,default=5000)
@click.option('--num_iters',type=int,default=30000)
@click.option('--warmup_iters',type=int,default=2500)
@click.option('--num_workers',type=int,default=2)
@click.option('--seed',type=int,default=42)
@click.option('--dir',type=str, required=True)
@click.option('--data_config_path',type=str, default='configs/gmm_configs/3_modes.yaml')
@click.option('--net_config_path',type=str, default='configs/toy_net.yaml')
@click.option('--load_checkpoint',type=str, help='Directory where we can find the desired checkpoints')
@click.option('--enable_wandb', is_flag=True, default=False)
def training(**opts):
    """Train the selected model with DistributedDataParallel.

    Runs until --num_iters optimizer steps have been taken. Every --log_rate
    steps (and at the last step) a snapshot and sample plots are written
    under --dir; a final checkpoint is written when training ends.
    """
    opts = dotdict(opts)
    batch_size = opts.batch_size

    # One process per GPU; the default rendezvous is used, so the launcher
    # (e.g. torchrun) has to provide the process-group environment.
    dist.init_process_group('nccl')
    world_size = dist.get_world_size()
    assert batch_size % world_size == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Give every rank its own seed so ranks do not share a random stream.
    seed = opts.seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Close the config file deterministically instead of leaking the handle.
    with open(opts.net_config_path) as config_file:
        net_opts = dotdict(yaml.safe_load(config_file))

    wandb_enabled = opts.enable_wandb and rank == 0 # We only want to log once
    if wandb_enabled:
        init_wandb(opts)
        wandb.config.update(net_opts)

    dataset = get_dataset(opts.dataset, opts.context_len, batch_size, config_path=opts.data_config_path)

    vocab_size = dataset.vocab_size
    context_len = dataset.context_len
    graph = utils.graph_lib.Absorbing(vocab_size)

    # The network is built with one token more than the dataset vocabulary.
    model = get_model(opts.model, vocab_size + 1, context_len, net_opts)
    ema = EMA(model, beta=opts.ema_beta)
    model = get_preconditioned_model(model, graph).to(device)
    ema = get_preconditioned_model(ema, graph).to(device)
    # NOTE(review): --optimizer (and --num_workers) are accepted on the command
    # line but never consulted; AdamW is always used. Left as-is because
    # honouring the flag would change the optimizer of default invocations.
    opt = torch.optim.AdamW(model.parameters(), lr=opts.lr)
    scheduler = WarmUpScheduler(opt, opts.warmup_iters)
    # GradScaler takes a device *type* string, not a CUDA index.
    scaler = torch.amp.GradScaler('cuda')

    start_iter = 0
    dist.barrier()
    if opts.load_checkpoint is not None:
        start_iter = load_checkpoint(opts, rank, device, model, ema, opt, scheduler)

    dist.barrier()

    ema.eval()
    model.train()
    model = DDP(model)

    if rank == 0:
        print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)//1e6} M")
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(opts.dir, exist_ok=True)

    num_iters = opts.num_iters

    loss_fn = get_loss(graph)
    sampling_fn = get_sampler(graph, device)

    training_iter = start_iter
    log_rate = opts.log_rate
    batches = iter(dataset)
    if rank == 0:
        # initial=start_iter keeps the bar consistent when resuming.
        pbar = tqdm(batches, total=num_iters, initial=start_iter, leave=False)
    else:
        pbar = batches
    for data, cond in pbar:
        # Stop after exactly num_iters optimizer steps. The former `>` test
        # let one extra step run after the final evaluation.
        if training_iter >= num_iters:
            break

        data = data.to(device=device)
        cond = cond.to(device=device).long()

        opt.zero_grad()

        loss = loss_fn(model, data, cond)

        scaler.scale(loss).backward()
        # Unscale first so clipping and sanitising act on true gradient values.
        scaler.unscale_(opt)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)

        for param in model.parameters():
            if param.grad is not None:
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)

        scaler.step(opt)
        scaler.update()
        ema.net.update()
        scheduler.step()

        training_iter += 1

        # Average the loss over ranks on a detached copy, so the collective
        # does not write in place into a tensor attached to the autograd graph.
        reduced_loss = loss.detach().clone()
        dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM)
        mean_loss = reduced_loss.item() / world_size

        if rank == 0:
            pbar.set_description(f'Iter {training_iter} --- Loss : {mean_loss :6.2f}')

        if wandb_enabled:
            # mean_loss is already the cross-rank mean; do not divide again.
            wandb.log({'loss': mean_loss})
        dist.barrier()
        # Evaluate sample accuracy
        if training_iter % log_rate == 0 or training_iter == num_iters:
            _checkpoint_and_sample(opts, rank, device, dataset, sampling_fn, model, ema,
                                   opt, scheduler, training_iter, context_len)

        dist.barrier()

    if rank == 0:
        save_ckpt(model, ema, opt, scheduler, os.path.join(opts.dir, 'final_checkpoint.pt'))

    dist.barrier()
    if wandb_enabled:
        wandb.finish()
    dist.destroy_process_group()

def load_checkpoint(opts, rank, device, model, ema, opt, scheduler):
    """Restore model, EMA, optimizer and scheduler state from a snapshot.

    Args:
        opts: Options object; ``opts.load_checkpoint`` is either a snapshot
            file or a directory containing ``snapshot.pt``.
        rank: Rank of the calling process (used for the log line only).
        device: Device the restored tensors should live on; an ``int`` is
            interpreted as a CUDA device index.
        model: Wrapper whose ``net`` attribute receives the model weights.
        ema: Wrapper whose ``net`` attribute is the EMA object
            (``ema_model``, ``online_model``, ``initted`` and ``step`` are set).
        opt: Optimizer to restore.
        scheduler: LR scheduler to restore.

    Returns:
        The iteration to resume from, read from ``scheduler.last_epoch``.
    """
    ckpt_path = opts.load_checkpoint
    # The CLI help describes this option as a directory, while snapshots are
    # written as '<dir>/snapshot.pt'; accept both the directory and the file.
    if os.path.isdir(ckpt_path):
        ckpt_path = os.path.join(ckpt_path, 'snapshot.pt')
    print(f'Loading checkpoint from {ckpt_path} in rank {rank}')

    # Without map_location every rank restores the tensors onto the device
    # they were saved from (normally cuda:0) instead of onto its own device.
    if isinstance(device, int):
        target_device = torch.device('cuda', device)
    else:
        target_device = torch.device(device)
    snapshot = torch.load(ckpt_path, map_location=target_device, weights_only=True)

    model.net.load_state_dict(snapshot['model'], strict=False)
    ema.net.ema_model.load_state_dict(snapshot['ema'], strict=True)
    ema.net.online_model.load_state_dict(snapshot['model'], strict=True)
    opt.load_state_dict(snapshot['optimizer'])
    scheduler.load_state_dict(snapshot['scheduler'])

    start_iter = scheduler.last_epoch
    # Mark the EMA as initialised: with initted == False, ema_pytorch copies
    # the online weights over the EMA weights on its next update, which would
    # throw away the EMA state that was just restored.
    ema.net.initted = torch.tensor(True, device=device)
    # NOTE(review): the EMA step counter is resumed one behind start_iter,
    # as in the original code -- confirm this offset is intended.
    ema.net.step = torch.tensor(start_iter - 1, device=device)
    return start_iter

def save_ckpt(model, ema, opt, scheduler, path):
    """Serialize the training state to ``path`` with ``torch.save``.

    The snapshot is a dict with the keys ``'model'``, ``'ema'``,
    ``'optimizer'`` and ``'scheduler'``, each holding the corresponding
    ``state_dict()``.

    Args:
        model: Wrapped model; the weights are read from ``model.module.net``.
        ema: EMA wrapper; the weights are read from ``ema.net.ema_model``.
        opt: Optimizer whose state is stored.
        scheduler: LR scheduler whose state is stored.
        path: Destination file.
    """
    stateful_parts = (
        ('model', model.module.net),
        ('ema', ema.net.ema_model),
        ('optimizer', opt),
        ('scheduler', scheduler),
    )
    torch.save({key: part.state_dict() for key, part in stateful_parts}, path)


# Script entry point: training is a click command, so calling it without
# arguments makes click parse the options from the command line.
if __name__ == '__main__':
    training()